{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Using Implicit Q-learning (an offline RL algorithm) in Pearl\n",
        "\n",
        "Here is a [better rendering](https://nbviewer.org/github/facebookresearch/Pearl/blob/main/tutorials/sequential_decision_making/Implicit_Q_learning.ipynb) of this notebook on [nbviewer](https://nbviewer.org/).\n",
        "\n",
        "- This tutorial shows how to use Pearl's implementation of Implicit Q-learning (IQL), an offline RL algorithm.\n",
        "\n",
        "- The example illustrates offline learning for continuous control using\n",
        "offline data collected from OpenAI Gym's HalfCheetah environment.\n"
      ],
      "metadata": {
        "id": "PBWPlSEq_fBf"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "0Fa4tQSG_YoI"
      },
      "outputs": [],
      "source": [
        "%load_ext autoreload\n",
        "%autoreload 2"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Pearl Installation\n",
        "\n",
        "If you haven't installed Pearl, run the following cell to install it. Otherwise, you can skip it."
      ],
      "metadata": {
        "id": "VLL6NfNABgQv"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Install Pearl from GitHub. This also installs PyTorch, Gym, and Matplotlib.\n",
        "\n",
        "%pip uninstall Pearl -y\n",
        "%rm -rf Pearl\n",
        "!git clone https://github.com/facebookresearch/Pearl.git\n",
        "%cd Pearl\n",
        "%pip install .\n",
        "%cd .."
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gZyj_w4sBg7L",
        "outputId": "66f23017-ff4c-4127-bc09-3259f573a1c6"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Successfully built Pearl AutoROM.accept-rom-license\n",
            "Installing collected packages: glfw, farama-notifications, parameterized, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, gymnasium, ale-py, shimmy, nvidia-cusparse-cu12, nvidia-cudnn-cu12, AutoROM.accept-rom-license, autorom, nvidia-cusolver-cu12, mujoco, Pearl\n",
            "Successfully installed AutoROM.accept-rom-license-0.6.1 Pearl-0.1.0 ale-py-0.8.1 autorom-0.4.2 farama-notifications-0.0.4 glfw-2.7.0 gymnasium-0.29.1 mujoco-3.1.3 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 parameterized-0.9.0 shimmy-0.2.1\n",
            "/content\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import os\n",
        "import pickle\n",
        "import torch\n",
        "\n",
        "\n",
        "from pearl.pearl_agent import PearlAgent\n",
        "from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n",
        "from pearl.utils.instantiations.environments.gym_environment import GymEnvironment\n",
        "from pearl.neural_networks.sequential_decision_making.actor_networks import VanillaContinuousActorNetwork\n",
        "from pearl.policy_learners.sequential_decision_making.implicit_q_learning import ImplicitQLearning\n",
        "\n",
        "from pearl.utils.functional_utils.experimentation.create_offline_data import (\n",
        "    get_data_collection_agent_returns,\n",
        ")\n",
        "\n",
        "from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import (\n",
        "    get_offline_data_in_buffer,\n",
        "    offline_evaluation,\n",
        "    offline_learning,\n",
        ")\n",
        "\n",
        "set_seed(0)"
      ],
      "metadata": {
        "id": "YoD1X9oyBnzD"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Specify some environment details"
      ],
      "metadata": {
        "id": "I4CEVZn9Bv3g"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "experiment_seed = 100\n",
        "\n",
        "# We choose a continuous control environment from MuJoCo called 'Half-Cheetah'.\n",
        "env_name = \"HalfCheetah-v4\"\n",
        "env = GymEnvironment(env_name)\n",
        "action_space = env.action_space\n",
        "is_action_continuous = True"
      ],
      "metadata": {
        "id": "tciCqt6WBumg",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "0b495d4c-5ae0-4c6f-a50e-9e989acb7b90"
      },
      "execution_count": 4,
      "outputs": [
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Download offline data\n",
        "\n",
        "- The offline data at https://github.com/facebookresearch/Pearl/tree/gh-pages/data is used for this tutorial.\n",
        "\n",
        "- The dataset was collected from the replay buffer of an experiment in which the soft actor-critic (SAC) algorithm was used for policy learning, with a large entropy parameter for exploration. It can be thought of as a 'medium replay' dataset in D4RL (https://github.com/Farama-Foundation/D4RL).\n",
        "\n",
        "- Note: Pearl is not integrated with D4RL because D4RL is being deprecated in favor of a new library called Minari. We plan to integrate Pearl with Minari at a later time."
      ],
      "metadata": {
        "id": "IykiuKhUCC3r"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import requests\n",
        "\n",
        "# Download the dataset of transition tuples from the GitHub URL and store it in a local file\n",
        "url = \"https://github.com/facebookresearch/Pearl/raw/gh-pages/data/offline_rl_data/HalfCheetah/offline_raw_transitions_dict_small_2.pt\"\n",
        "filename = \"offline_raw_transitions_dict_small_2.pt\"    # local file that will hold the dataset of transition tuples\n",
        "response = requests.get(url)\n",
        "response.raise_for_status()  # fail early if the download did not succeed\n",
        "with open(filename, \"wb\") as f:\n",
        "    f.write(response.content)"
      ],
      "metadata": {
        "id": "X4X4YspnDlUD"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "cwd = os.getcwd()   # current working directory\n",
        "data_path = os.path.join(cwd, filename)    # change this if the dataset of transition tuples is already stored at another local path\n",
        "\n",
        "# By default, the offline data replay buffer is stored on the CPU; see `get_offline_data_in_buffer` for device management\n",
        "offline_data_replay_buffer = get_offline_data_in_buffer(\n",
        "    is_action_continuous=is_action_continuous,\n",
        "    url=None,\n",
        "    data_path=data_path,  # path to local file which contains the offline dataset\n",
        ")"
      ],
      "metadata": {
        "id": "SWGnW4-0cw0D"
      },
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Set up an IQL agent\n",
        "\n",
        "- Here, `env` and `action_space` refer to the HalfCheetah environment set up above.\n"
      ],
      "metadata": {
        "id": "eD0Q6NwQngxf"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "offline_agent = PearlAgent(\n",
        "    policy_learner=ImplicitQLearning(\n",
        "        state_dim=env.observation_space.shape[0],\n",
        "        action_space=action_space,\n",
        "        actor_hidden_dims=[256, 256],\n",
        "        critic_hidden_dims=[256, 256],\n",
        "        value_critic_hidden_dims=[256, 256],\n",
        "        actor_network_type=VanillaContinuousActorNetwork,\n",
        "        value_critic_learning_rate=1e-3,\n",
        "        actor_learning_rate=3e-4,\n",
        "        critic_learning_rate=1e-4,\n",
        "        critic_soft_update_tau=0.05,\n",
        "        training_rounds=2,\n",
        "        batch_size=256,\n",
        "        expectile=0.75,\n",
        "        temperature_advantage_weighted_regression=3,\n",
        "    ),\n",
        ")"
      ],
      "metadata": {
        "id": "2nHPHmFenfoD"
      },
      "execution_count": 7,
      "outputs": []
    },
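    {
      "cell_type": "markdown",
      "source": [
        "The `expectile` argument above controls how asymmetrically IQL fits its value function to the critic. The cell below is a minimal sketch, assuming the standard expectile regression loss from the IQL paper. It is not Pearl's internal code, and the helper name `expectile_loss` is hypothetical. Residuals `Q(s, a) - V(s)` are squared and weighted by `expectile` when non-negative and by `1 - expectile` when negative."
      ],
      "metadata": {}
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "\n",
        "\n",
        "def expectile_loss(residual: torch.Tensor, expectile: float = 0.75) -> torch.Tensor:\n",
        "    # Hypothetical helper for illustration only; Pearl's own loss code may differ.\n",
        "    # residual = Q(s, a) - V(s). The weight is |expectile - 1(residual < 0)|.\n",
        "    weight = torch.abs(expectile - (residual < 0).float())\n",
        "    return (weight * residual.pow(2)).mean()\n",
        "\n",
        "\n",
        "# With expectile=0.75, a positive residual (V below Q) is weighted 3x more than a negative one of the same size.\n",
        "residuals = torch.tensor([1.0, -1.0])\n",
        "print(expectile_loss(residuals, expectile=0.75))"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },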
    {
      "cell_type": "markdown",
      "source": [
        "# Offline learning"
      ],
      "metadata": {
        "id": "UJ56YQqbpQ-U"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Number of training epochs\n",
        "training_epochs = 20000\n",
        "\n",
        "# Use the `offline_learning` utility function in Pearl to train an offline RL agent using offline data\n",
        "offline_learning(\n",
        "    offline_agent=offline_agent,\n",
        "    data_buffer=offline_data_replay_buffer, # replay buffer created using the offline data\n",
        "    training_epochs=training_epochs,\n",
        "    seed=experiment_seed,\n",
        ")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "oPe4Sux2pA_H",
        "outputId": "a7143493-6f54-406f-ca6c-fc8baaba0a10"
      },
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "training epoch 0 training loss {'value_loss': 0.44493526220321655, 'actor_loss': 1.1561717987060547, 'critic_loss': 5.130976676940918}\n",
            "training epoch 500 training loss {'value_loss': 0.1409292221069336, 'actor_loss': 0.5989271998405457, 'critic_loss': 10.62370491027832}\n",
            "training epoch 1000 training loss {'value_loss': 0.5230052471160889, 'actor_loss': 0.9807262420654297, 'critic_loss': 24.416900634765625}\n",
            "training epoch 1500 training loss {'value_loss': 0.9530619382858276, 'actor_loss': 3.278916597366333, 'critic_loss': 29.79534149169922}\n",
            "training epoch 2000 training loss {'value_loss': 1.614856243133545, 'actor_loss': 6.346199035644531, 'critic_loss': 43.3765869140625}\n",
            "training epoch 2500 training loss {'value_loss': 2.005828619003296, 'actor_loss': 7.089072227478027, 'critic_loss': 59.126861572265625}\n",
            "training epoch 3000 training loss {'value_loss': 2.0187668800354004, 'actor_loss': 3.204172134399414, 'critic_loss': 48.85820770263672}\n",
            "training epoch 3500 training loss {'value_loss': 3.2361795902252197, 'actor_loss': 1.1803092956542969, 'critic_loss': 61.56980895996094}\n",
            "training epoch 4000 training loss {'value_loss': 2.8879613876342773, 'actor_loss': 6.797780990600586, 'critic_loss': 41.75172805786133}\n",
            "training epoch 4500 training loss {'value_loss': 3.4644596576690674, 'actor_loss': 5.647995471954346, 'critic_loss': 64.77790832519531}\n",
            "training epoch 5000 training loss {'value_loss': 3.648473024368286, 'actor_loss': 3.0125679969787598, 'critic_loss': 66.62677001953125}\n",
            "training epoch 5500 training loss {'value_loss': 3.801520347595215, 'actor_loss': 6.153298377990723, 'critic_loss': 61.24081039428711}\n",
            "training epoch 6000 training loss {'value_loss': 4.654277801513672, 'actor_loss': 3.098684787750244, 'critic_loss': 75.80091094970703}\n",
            "training epoch 6500 training loss {'value_loss': 4.773356914520264, 'actor_loss': 5.105808258056641, 'critic_loss': 244.9698486328125}\n",
            "training epoch 7000 training loss {'value_loss': 5.341713905334473, 'actor_loss': 4.7835845947265625, 'critic_loss': 126.96758270263672}\n",
            "training epoch 7500 training loss {'value_loss': 4.727663516998291, 'actor_loss': 6.566059112548828, 'critic_loss': 51.88398742675781}\n",
            "training epoch 8000 training loss {'value_loss': 5.532264709472656, 'actor_loss': 5.600111961364746, 'critic_loss': 72.99192810058594}\n",
            "training epoch 8500 training loss {'value_loss': 5.432359218597412, 'actor_loss': 4.592930793762207, 'critic_loss': 60.98143005371094}\n",
            "training epoch 9000 training loss {'value_loss': 5.523667812347412, 'actor_loss': 5.735751152038574, 'critic_loss': 433.49212646484375}\n",
            "training epoch 9500 training loss {'value_loss': 5.216910362243652, 'actor_loss': 4.710633754730225, 'critic_loss': 83.78570556640625}\n",
            "training epoch 10000 training loss {'value_loss': 5.5933661460876465, 'actor_loss': 6.106831073760986, 'critic_loss': 147.11474609375}\n",
            "training epoch 10500 training loss {'value_loss': 8.451723098754883, 'actor_loss': 1.621864676475525, 'critic_loss': 80.56407165527344}\n",
            "training epoch 11000 training loss {'value_loss': 5.930197715759277, 'actor_loss': 6.696375846862793, 'critic_loss': 51.827354431152344}\n",
            "training epoch 11500 training loss {'value_loss': 6.042293071746826, 'actor_loss': 4.567534446716309, 'critic_loss': 49.92619323730469}\n",
            "training epoch 12000 training loss {'value_loss': 6.438072204589844, 'actor_loss': 6.379092693328857, 'critic_loss': 61.6492919921875}\n",
            "training epoch 12500 training loss {'value_loss': 7.940831661224365, 'actor_loss': 8.604472160339355, 'critic_loss': 75.86102294921875}\n",
            "training epoch 13000 training loss {'value_loss': 5.591019630432129, 'actor_loss': 5.480114936828613, 'critic_loss': 48.68461990356445}\n",
            "training epoch 13500 training loss {'value_loss': 7.129138946533203, 'actor_loss': 7.606431007385254, 'critic_loss': 373.935546875}\n",
            "training epoch 14000 training loss {'value_loss': 6.827956199645996, 'actor_loss': 4.753203392028809, 'critic_loss': 53.25279235839844}\n",
            "training epoch 14500 training loss {'value_loss': 6.406070709228516, 'actor_loss': 4.576913833618164, 'critic_loss': 312.50872802734375}\n",
            "training epoch 15000 training loss {'value_loss': 5.100269317626953, 'actor_loss': 5.159261703491211, 'critic_loss': 41.06599426269531}\n",
            "training epoch 15500 training loss {'value_loss': 5.180943965911865, 'actor_loss': 3.508758306503296, 'critic_loss': 117.24136352539062}\n",
            "training epoch 16000 training loss {'value_loss': 5.463344573974609, 'actor_loss': 6.216361999511719, 'critic_loss': 43.10062026977539}\n",
            "training epoch 16500 training loss {'value_loss': 5.401617527008057, 'actor_loss': 5.747153282165527, 'critic_loss': 55.972713470458984}\n",
            "training epoch 17000 training loss {'value_loss': 3.9423627853393555, 'actor_loss': 5.501956462860107, 'critic_loss': 37.53641891479492}\n",
            "training epoch 17500 training loss {'value_loss': 5.858860015869141, 'actor_loss': 5.931163311004639, 'critic_loss': 52.11505889892578}\n",
            "training epoch 18000 training loss {'value_loss': 4.535668849945068, 'actor_loss': 6.64661979675293, 'critic_loss': 39.68230056762695}\n",
            "training epoch 18500 training loss {'value_loss': 5.359783172607422, 'actor_loss': 4.801802158355713, 'critic_loss': 48.881675720214844}\n",
            "training epoch 19000 training loss {'value_loss': 4.827009201049805, 'actor_loss': 5.660262584686279, 'critic_loss': 38.47284698486328}\n",
            "training epoch 19500 training loss {'value_loss': 5.641027450561523, 'actor_loss': 4.343751430511475, 'critic_loss': 49.17301559448242}\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Offline evaluation"
      ],
      "metadata": {
        "id": "anb1UsbspTNE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Use the `offline_evaluation` utility function in Pearl to evaluate the trained agent by interacting with the environment.\n",
        "\n",
        "offline_evaluation_returns = offline_evaluation(\n",
        "    offline_agent=offline_agent,\n",
        "    env=env,\n",
        "    number_of_episodes=50,\n",
        "    seed=experiment_seed,\n",
        ")\n",
        "\n",
        "# mean evaluation returns of the offline agent\n",
        "avg_offline_agent_returns = torch.mean(torch.tensor(offline_evaluation_returns))\n",
        "print(f\"average returns of the offline agent {avg_offline_agent_returns}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wAu_fL6EpXfL",
        "outputId": "97718810-4cc1-4f52-af2e-45b6257aef2d"
      },
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "episode 49, return=244.23391946865013average returns of the offline agent 230.69460202311308\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "# Getting normalized scores (typical for benchmarking offline RL algorithms)\n"
      ],
      "metadata": {
        "id": "gNddjWLf1EOk"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Step 1: get episodic returns of a random agent using the 'returns_random_agent.pickle' file on github.\n",
        "# You can also do this by initializing an offline agent randomly and interacting with the environment.\n",
        "\n",
        "\n",
        "# URL of the file on GitHub\n",
        "random_returns_url = \"https://github.com/facebookresearch/Pearl/raw/gh-pages/data/offline_rl_data/HalfCheetah/returns_random_agent.pickle\"\n",
        "\n",
        "# Download the file\n",
        "response = requests.get(random_returns_url)\n",
        "returns = response.content\n",
        "\n",
        "# Load the data from the file\n",
        "with open('random_agent_returns.pkl', 'wb') as f:\n",
        "    f.write(returns)\n",
        "with open('random_agent_returns.pkl', 'rb') as f:\n",
        "    random_agent_returns = pickle.load(f)\n",
        "\n",
        "\n",
        "avg_return_random_agent = torch.mean(torch.tensor(random_agent_returns))    # mean returns of a random agent\n",
        "print(f\"average returns of a random agent {avg_return_random_agent}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "xvpt-b76zKI4",
        "outputId": "069b3730-96d0-4660-dde7-34f17338c6f3"
      },
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "average returns of a random agent -426.930167355092\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Step 2: get training returns in the data (i.e. episodic returns of the data collection agent).\n",
        "\n",
        "# The `get_data_collection_agent_returns` function computes the episodic returns from the offline data of transition tuples\n",
        "# Note 1: We implicitly assume that the offline data was collected in an episodic manner\n",
        "# Note 2: The `data_path` points to the local file with offline data. Recall that we set data_path = cwd + '/' + filename, where filename = \"offline_raw_transitions_dict_small_2.pt\"\n",
        "data_collection_agent_returns = get_data_collection_agent_returns(\n",
        "    data_path=data_path\n",
        ")\n",
        "\n",
        "# average trajectory returns in the dataset\n",
        "avg_return_data_collection_agent = torch.mean(\n",
        "    torch.tensor(data_collection_agent_returns)\n",
        ")\n",
        "print(\n",
        "    f\"average returns of the data collection agent {avg_return_data_collection_agent}\"\n",
        ")\n",
        "\n",
        "\n",
        "max_return_data_collection_agent = torch.max(torch.tensor(data_collection_agent_returns))   # maximum trajectory returns in the dataset\n",
        "print(f\"maximum returns of the data collection agent {max_return_data_collection_agent}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NkgL8dVT5okl",
        "outputId": "1b218ae1-aad5-403e-b666-5344d5b18fed"
      },
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "getting returns of the data collection agent agent\n",
            "using offline training data in /content/offline_raw_transitions_dict_small_2.pt to stitch trajectories and compute returns\n",
            "average returns of the data collection agent 490.7219223530043\n",
            "maximum returns of the data collection agent 1878.2492669075727\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The following is one way to define the normalized score, which is a proxy for\n",
        "# how good the offline learning algorithm is as compared to the policy that was used to collect data.\n",
        "\n",
        "normalized_score = (avg_offline_agent_returns - avg_return_random_agent) / (\n",
        "    avg_return_data_collection_agent - avg_return_random_agent\n",
        ")\n",
        "\n",
        "print(f\"normalized score {normalized_score}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "j3rVgiB76E8V",
        "outputId": "3517c80e-64cd-4b5e-9144-d254db799402"
      },
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "normalized score 0.716638448006362\n"
          ]
        }
      ]
    },
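    {
      "cell_type": "markdown",
      "source": [
        "The normalized score from the cell above can also be written as a formula. The symbols below are notation introduced here for illustration (they do not appear in Pearl): $\\bar{R}_{\\text{offline}}$, $\\bar{R}_{\\text{random}}$ and $\\bar{R}_{\\text{data}}$ denote the average episodic returns of the offline agent, the random agent and the data collection agent, respectively.\n",
        "\n",
        "$$\n",
        "\\text{normalized score} = \\frac{\\bar{R}_{\\text{offline}} - \\bar{R}_{\\text{random}}}{\\bar{R}_{\\text{data}} - \\bar{R}_{\\text{random}}}\n",
        "$$\n",
        "\n",
        "Plugging in the (rounded) values printed above gives\n",
        "\n",
        "$$\n",
        "\\frac{230.69 - (-426.93)}{490.72 - (-426.93)} = \\frac{657.62}{917.65} \\approx 0.717,\n",
        "$$\n",
        "\n",
        "which matches the printed value. Under this definition, a score of 0 corresponds to the random agent and a score of 1 to the average return in the offline data."
      ],
      "metadata": {}
    },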
    {
      "cell_type": "markdown",
      "source": [
        "Note 1: A normalized score close to 1 or greater than 1 indicates good performance. However, note that the offline dataset we use here is very small (100,000 transition tuples), so we do not expect to see a high normalized score.\n",
        "\n",
        "Note 2: We have used average episodic returns in the offline data as a proxy for the performance of data collection agent (which is used when computing the normalized score).\n",
        "\n",
        "- An ideal way to do this would be to run the data collection agent/policy, at the end of the data collection phase, for a few episodes and take the average episodic returns.\n",
        "- This would approximate the 'best' policy used to collect data. The real test for offline learning algorithms is to be able to beat this 'best' policy."
      ],
      "metadata": {
        "id": "h5jzoRxiRGe6"
      }
    }
  ]
}